Re-enable time-dependent z-scoring for Flow Matching#1752
satwiksps wants to merge 7 commits into sbi-dev:main from
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main    #1752      +/-   ##
==========================================
- Coverage   88.54%   88.07%   -0.48%
==========================================
  Files         137      137
  Lines       11515    12258     +743
==========================================
+ Hits        10196    10796     +600
- Misses       1319     1462     +143
```

Flags with carried forward coverage won't be shown.
It seems this failure is unrelated to this PR. The actual Flow Matching benchmarks and integration tests for this PR passed successfully, though.
This was an old bug that surfaced now, likely because codecov was trying to serialize things.
Yes, this is unrelated and popped up here by chance, or because of an unrelated change in a downstream package. I pushed a fix to this branch ✅
Thanks for working on this @satwiksps! Overall, this looks exactly right. However, after reviewing the code and tracing through the flow matching implementation, I believe the z-scoring formula is inverted relative to the interpolation convention (quite confusing!). The interpolation in the loss function is:

So the expected input mean at each time is:

Current PR formula:

This gives mu_t = 0 at t=0 and mu_t = mean_data at t=1, i.e., exactly backwards. The correct formula should be:

The formula only matches at t=0.5 and is maximally wrong at the boundaries. To verify this, I suggest the following test: the standard linear Gaussian test, but with a uniform prior between 95 and 100, and with data

Can you confirm this (maybe I got confused with the integration directions after all)?
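To make the boundary behavior concrete, here is a small Monte-Carlo sketch (assuming the convention x_t = t·x_1 + (1−t)·x_0 with standard-normal noise at t=0; `mean_data = 97.5` is an illustrative value matching the uniform(95, 100) prior, not a value from the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
mean_data, n = 97.5, 200_000  # illustrative data mean near the uniform(95, 100) prior

x1 = rng.normal(mean_data, 1.0, n)  # "data" samples at t=1
x0 = rng.normal(0.0, 1.0, n)        # standard-normal noise at t=0

for t in (0.0, 0.5, 1.0):
    xt = t * x1 + (1 - t) * x0  # linear interpolation used in the FM loss
    # the empirical mean of x_t tracks t * mean_data:
    # 0 at t=0 and mean_data at t=1
    print(t, xt.mean(), t * mean_data)
```

Under this convention, the input mean at time t is t·mean_data, so whichever z-scoring formula matches that quantity at both boundaries is the consistent one.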
```python
# call the network to get the estimated vector field
v = self.net(input, condition_emb, time)
t_view = time.view(-1, *([1] * (input.ndim - 1)))
```
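The reshape of `time` makes the per-sample time broadcastable over the trailing input dimensions. A NumPy mirror of the same pattern (shapes here are hypothetical, just for illustration):

```python
import numpy as np

# hypothetical shapes: a batch of 4 inputs with 3 parameter dimensions
x = np.random.randn(4, 3)
time = np.random.rand(4)  # one scalar time per batch element

# reshape time to (4, 1) so it broadcasts over the trailing input dims,
# mirroring time.view(-1, *([1] * (input.ndim - 1))) in the PyTorch code
t_view = time.reshape(-1, *([1] * (x.ndim - 1)))
assert t_view.shape == (4, 1)

# per-sample time now scales per-sample inputs elementwise
scaled = t_view * x
assert scaled.shape == (4, 3)
```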
Assuming a Gaussian target at t=1 with the here-given mu1 and std1, the exact marginal velocity would have the following form:

```python
# ---- marginal Gaussian stats (alpha=t, sigma=1-t, diag C = s1^2) ----
mu_t = t_view * m                                         # \bar{mu}_t
var_t = (t_view.square() * s1_sq) + one_minus_t.square()  # diag(S_t)
std_t = var_t.sqrt().clamp_min(self.eps)
# ---- z-scoring scaling for the net (as currently) ----
x_centered = x - mu_t
x_norm = x_centered / std_t                               # c_in * (x - mu_t)
resid_norm = self.net(x_norm, condition_emb, t)           # f_theta(...)
resid = resid_norm * std_t                                # c_out * f_theta
# ---- Gaussian posterior mean E[x1 | xt=x] under diag prior ----
# k_t = alpha * C / S_t with alpha=t and C=s1^2 (diagonal)
k_t = (t_view * s1_sq) / var_t
x1_hat = m + k_t * x_centered                             # m + k_t (x - t m)
# ---- Gaussian affine baseline: a(t)=t, b(t)=1-t ----
u_gauss = (t_view * x) + (one_minus_t * x1_hat)
```

Although this is only with respect to the "prior" (i.e., not the posterior), it might still be reasonable.
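The posterior-mean step in the snippet above can be sanity-checked numerically. A quick Monte-Carlo sketch for the scalar case (the values `m = 2.0`, `s1 = 0.5`, `t = 0.7`, and the query point `x = 1.5` are illustrative, not from the PR):

```python
import numpy as np

rng = np.random.default_rng(1)
m, s1, t, n = 2.0, 0.5, 0.7, 1_000_000

x1 = rng.normal(m, s1, n)      # Gaussian target at t=1
x0 = rng.normal(0.0, 1.0, n)   # standard-normal noise at t=0
xt = t * x1 + (1 - t) * x0     # marginal sample at time t

# closed-form shrinkage factor k_t = t * s1^2 / (t^2 s1^2 + (1-t)^2),
# as in the snippet above (alpha = t, sigma = 1 - t, diagonal covariance)
var_t = t**2 * s1**2 + (1 - t) ** 2
k_t = t * s1**2 / var_t

# check E[x1 | x_t = x] ~= m + k_t * (x - t*m) on a thin slice around x
x = 1.5
mask = np.abs(xt - x) < 0.01
print(x1[mask].mean(), m + k_t * (x - t * m))
```

The empirical conditional mean on the slice should agree with the affine formula up to Monte-Carlo noise, which is exactly the Gaussian-conditioning identity the `x1_hat` line relies on.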
manuelgloeckler left a comment:
Hey @satwiksps!
Thanks for the contribution! I checked against main, and as of now it performs on average, I guess, very similarly if not a bit worse than before (although I think that's mostly fine, i.e., on these tasks).
I wonder if it would make sense to improve the "preconditioning" a bit more (see comments).
Thanks for adding the comparison.
Alright, I looked at it again and I realized that my proposal was actually incorrect. The formulas I proposed would result in total normalization, i.e., "independent" z-scoring, where all time steps have equal zero mean after z-scoring, and we lose valuable time-dependent information. Sorry @satwiksps, your formulas were actually correct! What Manuel proposed is great: we z-score with respect to the Gaussian baseline, i.e., what one would expect when the posterior is actually Gaussian. Then the flow matching network only has to learn the residual from this ideal baseline (please correct me @manuelgloeckler if this intuition is inaccurate). I tested this locally with the following setup:
Results:
Thus, @satwiksps, I suggest you implement both options, your proposal and Manuel's proposal, add the test as a new z-scoring test, and confirm the results.
@janfb The preconditioning is with respect to the "prior", not the posterior (as the latter would require regression from x). I don't think that it will "hurt" in almost all cases, i.e., FM nets are initialized to output zero, hence the initialized network will effectively sample from a mass-covering Gaussian approximation of the prior (and everything else needs to be learned). Nonetheless, having an option to disable it is always good. I agree that the benchmark tests are not really sensitive to the z-scoring, but as we usually enable z-scoring by default, it shouldn't hurt performance even if it's not necessary. But as said, the deviation is small enough to be fine (and might improve with the additional baseline).

Description
This PR re-introduces z-scoring for Flow Matching estimators using a time-dependent normalization approach.
As discussed in #1623, standard z-scoring at $t=0$ is problematic because the network input is noise, not data. This implementation interpolates the normalization statistics based on the time step $t$, ensuring the network always receives inputs with standard statistics:
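The exact interpolation is defined in the PR's code; as one plausible sketch of the idea (the linear blend of the standard deviations below is an assumption for illustration, and `time_dependent_zscore` is a hypothetical helper name, not the sbi API):

```python
import numpy as np

def time_dependent_zscore(x_t, t, mean_1, std_1):
    """Sketch of time-dependent z-scoring: at t=0 the input is standard
    noise (mean 0, std 1), at t=1 it has the data statistics (mean_1, std_1).
    The linear blend of std below is an illustrative choice, not necessarily
    the exact formula used in the PR."""
    mu_t = t * mean_1
    std_t = t * std_1 + (1.0 - t)
    return (x_t - mu_t) / std_t

x = np.array([10.0, 12.0])
# at t=0 this is the identity (input is already standard normal)
print(time_dependent_zscore(x, 0.0, mean_1=11.0, std_1=1.0))  # -> [10., 12.]
# at t=1 it is the usual z-score of the data
print(time_dependent_zscore(x, 1.0, mean_1=11.0, std_1=1.0))  # -> [-1., 1.]
```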
Related Issues/PRs
Changes
- `sbi/neural_nets/net_builders/vector_field_nets.py`: Updated `build_vector_field_estimator` to calculate the training data statistics (`mean` and `std`) and pass them to the estimator.
- `sbi/neural_nets/estimators/flowmatching_estimator.py`: Registered `mean_1` and `std_1` as buffers (initialized as floats to ensure type consistency); updated `forward()` to apply the time-dependent z-scoring formula.
- `tests/linearGaussian_vector_field_test.py`: Added a new integration test (`test_fmpe_time_dependent_z_scoring_integration`) to verify that statistics are correctly learned and the forward pass executes without errors.

Verification
- Verified that `mean_1` and `std_1` are populated and the `ode_fn` runs correctly.
- Ran `sbi` benchmarks locally (`pytest --bm --bm-mode fmpe`) to check for stability and performance. All 12 tests passed successfully (screenshot attached below).
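The buffer registration described in the Changes section can be sketched as follows (class and default values are illustrative, not the exact sbi implementation):

```python
import torch
from torch import nn

class FlowMatchingEstimatorSketch(nn.Module):
    """Minimal sketch of storing z-scoring statistics as buffers."""

    def __init__(self, mean_1: float = 0.0, std_1: float = 1.0):
        super().__init__()
        # buffers move with .to(device) and are saved in the state_dict,
        # but are not trainable parameters; casting to float ensures a
        # consistent floating-point dtype regardless of the input type
        self.register_buffer("mean_1", torch.as_tensor(float(mean_1)))
        self.register_buffer("std_1", torch.as_tensor(float(std_1)))

est = FlowMatchingEstimatorSketch(mean_1=2.5, std_1=0.3)
print("mean_1" in est.state_dict())  # buffers appear in the state dict
```

Using buffers rather than plain attributes means the learned statistics survive `state_dict()` round-trips, so a reloaded estimator normalizes inputs identically.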